{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.5 读取和存储"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.4.1\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Show the PyTorch version running in this kernel.\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.5.1 读写`Tensor`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Create a length-3 tensor of ones and write it to the file 'x.pt'.\n",
    "x = torch.ones(3)\n",
    "torch.save(x, 'x.pt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1., 1., 1.])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Read the tensor back from 'x.pt'; the bare name on the last line displays it.\n",
    "x2 = torch.load('x.pt')\n",
    "x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A Python list of tensors is saved and loaded back as a single object.\n",
    "y = torch.zeros(4)\n",
    "torch.save([x, y], 'xy.pt')\n",
    "xy_list = torch.load('xy.pt')\n",
    "xy_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same round trip with a dict that maps string keys to tensors.\n",
    "torch.save({'x': x, 'y': y}, 'xy_dict.pt')\n",
    "xy = torch.load('xy_dict.pt')\n",
    "xy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.5.2 读写模型\n",
    "### 4.5.2.1 `state_dict`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('hidden.weight', tensor([[ 0.1836, -0.1812, -0.1681],\n",
       "                      [ 0.0406,  0.3061,  0.4599]])),\n",
       "             ('hidden.bias', tensor([-0.3384,  0.1910])),\n",
       "             ('output.weight', tensor([[0.0380, 0.4919]])),\n",
       "             ('output.bias', tensor([0.1451]))])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Small multilayer perceptron: Linear(3, 2), then ReLU, then Linear(2, 1).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Layers are created in the order hidden, act, output; the keys of\n",
    "        # state_dict() follow this registration order.\n",
    "        self.hidden = nn.Linear(3, 2)\n",
    "        self.act = nn.ReLU()\n",
    "        self.output = nn.Linear(2, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply hidden, then the activation, then output to `x` and return the result.\"\"\"\n",
    "        hidden_activation = self.act(self.hidden(x))\n",
    "        prediction = self.output(hidden_activation)\n",
    "        return prediction\n",
    "\n",
    "# Instantiate the model and display its state_dict (parameter name -> tensor).\n",
    "net = MLP()\n",
    "net.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'param_groups': [{'dampening': 0,\n",
       "   'lr': 0.001,\n",
       "   'momentum': 0.9,\n",
       "   'nesterov': False,\n",
       "   'params': [4624483024, 4624484608, 4624484680, 4624484752],\n",
       "   'weight_decay': 0}],\n",
       " 'state': {}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The optimizer exposes a state_dict as well. The recorded output has a\n",
    "# 'param_groups' list (hyperparameters plus the tracked params) and a\n",
    "# 'state' entry that is still empty at this point.\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "optimizer.state_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4.5.2.2 保存和加载模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1],\n",
       "        [1]], dtype=torch.uint8)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Round-trip check: save only net's state_dict, load it into a freshly\n",
    "# constructed MLP, and compare both models' outputs on the same input X.\n",
    "X = torch.randn(2, 3)\n",
    "Y = net(X)\n",
    "\n",
    "PATH = \"./net.pt\"\n",
    "torch.save(net.state_dict(), PATH)\n",
    "\n",
    "net2 = MLP()\n",
    "net2.load_state_dict(torch.load(PATH))\n",
    "Y2 = net2(X)\n",
    "# Element-wise comparison; the recorded output shows every entry equal.\n",
    "Y2 == Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
